import torch.nn as nn
from .utils.noisytransformers import NoisyTransformerClassifier

from .utils.tokenizer import Tokenizer, TextTokenizer
from .utils.embedder import Embedder
from torch.nn import Parameter
import torch

__all__ = ['noisyvit_lite_7']


class NoisyViTLite(nn.Module):
    """ViT-Lite model whose transformer classifier accepts noise/mixing arguments.

    The input image is cut into non-overlapping ``patch_size`` x ``patch_size``
    patches by a single-convolution ``Tokenizer`` (kernel and stride both equal
    to ``patch_size``, no pooling, no activation). The resulting token sequence
    is fed to a ``NoisyTransformerClassifier`` without sequence pooling.

    Args:
        img_size: Height and width of the (square) input image. Must be
            divisible by ``patch_size``.
        embedding_dim: Output channels of the patch-embedding convolution,
            also used as the transformer embedding dimension.
        n_input_channels: Number of channels in the input image.
        patch_size: Kernel size and stride of the patch-embedding convolution.
        add_noise: Stored as ``self.add_noise``. NOTE(review): ``forward`` does
            not read this attribute; noise levels must be passed to
            ``forward`` explicitly (its defaults are 0.0).
        mult_noise: Stored as ``self.mult_noise``; same caveat as ``add_noise``.
        *args, **kwargs: Forwarded unchanged to ``NoisyTransformerClassifier``
            (the factory functions pass ``num_layers``, ``num_heads`` and
            ``mlp_ratio`` this way).

    Raises:
        AssertionError: If ``img_size`` is not divisible by ``patch_size``.
    """

    def __init__(self,
                 img_size=224,
                 embedding_dim=768,
                 n_input_channels=3,
                 patch_size=16,
                 add_noise=0.08,
                 mult_noise=0.02,
                 *args, **kwargs):
        super().__init__()
        # Separate the two f-string fragments with a space; without it the
        # message reads "... has to bedivisible ...".
        assert img_size % patch_size == 0, (
            f"Image size ({img_size}) has to be "
            f"divisible by patch size ({patch_size})"
        )
        self.tokenizer = Tokenizer(n_input_channels=n_input_channels,
                                   n_output_channels=embedding_dim,
                                   kernel_size=patch_size,
                                   stride=patch_size,
                                   padding=0,
                                   max_pool=False,
                                   activation=None,
                                   n_conv_layers=1,
                                   conv_bias=True)

        self.classifier = NoisyTransformerClassifier(
            sequence_length=self.tokenizer.sequence_length(n_channels=n_input_channels,
                                                           height=img_size,
                                                           width=img_size),
            embedding_dim=embedding_dim,
            seq_pool=False,
            dropout_rate=0.1,
            attention_dropout=0.,
            stochastic_depth=0.,
            *args, **kwargs)

        # Zero-element parameter that this class never reads.
        # NOTE(review): presumably kept so callers can discover the model's
        # device/dtype via its parameters — confirm with callers.
        self.dummy_param = Parameter(torch.empty(0))
        # Stored for callers; not consulted by forward() (see class docstring).
        self.add_noise = add_noise
        self.mult_noise = mult_noise

    def forward(self, x, lam=1.0, k=-1, add_noise=0.0, mult_noise=0.0):
        """Tokenize ``x`` and run the noisy transformer classifier on the tokens.

        Args:
            x: Input image batch. Presumably shaped
                (batch, n_input_channels, img_size, img_size) — TODO confirm.
            lam: Passed to the classifier as ``lam``. Presumably a mixing
                coefficient — confirm in NoisyTransformerClassifier.
            k: Passed to the classifier as ``k``. Presumably selects where
                mixing/noise is applied — confirm in NoisyTransformerClassifier.
            add_noise: Forwarded as ``add_noise_level``.
            mult_noise: Forwarded as ``mult_noise_level``.

        Returns:
            Whatever ``NoisyTransformerClassifier`` returns for these inputs.
        """
        x = self.tokenizer(x)
        return self.classifier(x, lam=lam, k=k, add_noise_level=add_noise, mult_noise_level=mult_noise)




def _noisyvit_lite(num_layers, num_heads, mlp_ratio, embedding_dim,
                   patch_size=4, *args, **kwargs):
    """Construct a ``NoisyViTLite`` from transformer depth/width settings.

    ``embedding_dim`` and ``patch_size`` are named parameters of
    ``NoisyViTLite``. ``num_layers``, ``num_heads`` and ``mlp_ratio`` are not,
    so they land in its ``**kwargs`` and are forwarded to its classifier.
    Any additional ``*args``/``**kwargs`` pass through unchanged.
    """
    architecture = {
        'num_layers': num_layers,
        'num_heads': num_heads,
        'mlp_ratio': mlp_ratio,
        'embedding_dim': embedding_dim,
        'patch_size': patch_size,
    }
    return NoisyViTLite(*args, **architecture, **kwargs)



def noisyvit_lite_7(*args, **kwargs):
    """Return the 7-layer NoisyViT-Lite variant.

    The architecture is fixed at 7 transformer layers, 4 attention heads,
    MLP ratio 2 and embedding dimension 256. All other arguments (for example
    ``img_size``, ``patch_size``, ``n_input_channels``) are passed through
    to ``_noisyvit_lite``.
    """
    return _noisyvit_lite(
        *args,
        num_layers=7,
        num_heads=4,
        mlp_ratio=2,
        embedding_dim=256,
        **kwargs
    )


